import numpy as np
import torch

from attacks import Attack
import torch.nn.functional as F
from scipy import stats as st

from constants import DEVICE
from utils import cross_entropy_loss, de_normalization, normalization


class TIM(Attack):
    """TIM (Translation-Invariant Method) combined with MI-FGSM.

    Each iteration computes the input gradient of the loss, smooths it with a
    fixed depthwise kernel (the translation-invariance step), accumulates it
    into an L1-normalised momentum buffer, and takes a signed step of size
    ``eps / steps`` in de-normalised pixel space. The step is then projected
    back onto the L-inf ball of radius ``eps`` around the clean images,
    intersected with [0, 1].
    """

    def __init__(self, model, eps=16 / 255, steps=10, decay=1.0, kernel_name='gaussian', len_kernel=15, nsig=3):
        """
        :param model: DNN model
        :param eps: the maximum L-inf perturbation, in de-normalised [0, 1] pixel units
        :param steps: the number of iterations (the per-step size is eps / steps)
        :param decay: the momentum decay factor
        :param kernel_name: type of gradient-smoothing kernel; only 'gaussian' is supported
        :param len_kernel: side length of the square smoothing kernel
        :param nsig: the Gaussian is sampled on [-nsig, nsig] (in standard deviations)
        :raises NotImplementedError: if ``kernel_name`` is not supported
        """
        super().__init__("TIM", model)
        self.eps = eps
        self.steps = steps
        self.alpha = self.eps / self.steps
        self.decay = decay
        self.kernel_name = kernel_name
        self.len_kernel = len_kernel
        self.nsig = nsig
        # Depthwise conv2d weights of shape (3, 1, len_kernel, len_kernel).
        self.stacked_kernel = torch.from_numpy(self.kernel_generation()).to(DEVICE)

    def kernel_generation(self):
        """Build the depthwise smoothing kernel applied to the gradient.

        :return: float32 array of shape (3, 1, len_kernel, len_kernel). It holds the
            same 2D kernel once per input channel, laid out as ``F.conv2d`` weights
            for a ``groups=3`` convolution.
        :raises NotImplementedError: if ``self.kernel_name`` is not 'gaussian'
        """
        if self.kernel_name != 'gaussian':
            raise NotImplementedError(f"Unsupported kernel_name: {self.kernel_name!r}")
        kernel = self.gkern().astype(np.float32)

        # One copy per channel plus a singleton in-channel axis, as a depthwise
        # (grouped) convolution expects.
        stack_kernel = np.stack([kernel, kernel, kernel])
        return np.expand_dims(stack_kernel, 1)

    def gkern(self):
        """Return a 2D Gaussian kernel of shape (len_kernel, len_kernel) that sums to 1."""
        x = np.linspace(-self.nsig, self.nsig, self.len_kernel)
        kern1d = st.norm.pdf(x)
        kernel_raw = np.outer(kern1d, kern1d)
        return kernel_raw / kernel_raw.sum()

    def forward(self, images, labels):
        """Craft adversarial examples for ``images``.

        :param images: normalised input batch; assumed (N, 3, H, W) to match the
            3-group kernel (TODO confirm with callers)
        :param labels: ground-truth integer class index per image
        :return: detached adversarial batch, in the same normalised space as ``images``
        """
        # Detach so neither the bounds nor the result carry autograd history.
        images = images.clone().detach()
        # NOTE(review): de_normalization/normalization are assumed to map between
        # the model's input space and [0, 1] pixel space. The clamps below rely on it.
        images_de_normalized = de_normalization(images)
        images_min = torch.clamp(images_de_normalized - self.eps, min=0.0, max=1.0)
        images_max = torch.clamp(images_de_normalized + self.eps, min=0.0, max=1.0)

        adv = images.clone()
        g = torch.zeros_like(images)
        targets = None
        for _ in range(self.steps):
            # autograd.grad needs adv to require grad. Making it a fresh leaf
            # each step also keeps the graph from growing across iterations.
            adv.requires_grad_(True)
            y_predicts = self.model(adv)
            if targets is None:
                # Size the one-hot targets from the logits instead of assuming 1000 classes.
                targets = F.one_hot(labels.type(torch.int64), y_predicts.shape[-1]).float().to(DEVICE)
            loss = cross_entropy_loss(y_predicts, targets)
            grad = torch.autograd.grad(loss, adv)[0]

            # Translation invariance: smooth the gradient with the depthwise kernel.
            # Cast the kernel so reduced/extended precision models also work.
            kernel = self.stacked_kernel.to(dtype=grad.dtype)
            grad = F.conv2d(grad, kernel, stride=1, padding='same', groups=kernel.shape[0])

            # Momentum with per-sample L1 normalisation. An all-zero gradient would
            # otherwise give 0/0 = NaN, so its denominator is replaced by 1.
            grad_l1 = torch.mean(torch.abs(grad), dim=(1, 2, 3), keepdim=True)
            grad_l1 = torch.where(grad_l1 > 0, grad_l1, torch.ones_like(grad_l1))
            g = self.decay * g + grad / grad_l1

            with torch.no_grad():
                adv_de_normalized = de_normalization(adv)
                adv_de_normalized = torch.clamp(adv_de_normalized + self.alpha * torch.sign(g), min=images_min,
                                                max=images_max)
                adv = normalization(adv_de_normalized).detach()

        return adv
